//! Example JWT authorization/authentication.
//!
//! Run with
//!
//! ```not_rust
//! JWT_SECRET=secret cargo run -p example-jwt
//! ```

use axum::{
  async_trait,
  extract::{FromRequest, RequestParts, TypedHeader},
  headers::{authorization::Bearer, Authorization},
  http::StatusCode,
  response::{IntoResponse, Response},
  routing::{get, post},
  Json, Router,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{fmt::Display, net::SocketAddr};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// Quick instructions
//
// - get an authorization token:
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -d '{"client_id":"foo","client_secret":"bar"}' \
//     http://localhost:3000/authorize
//
// - visit the protected area using the authorization token
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \
//     http://localhost:3000/protected
//
// - try to visit the protected area using an invalid token
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -H 'Authorization: Bearer blahblahblah' \
//     http://localhost:3000/protected

/// Process-wide signing/verification keys.
///
/// Built lazily from the `JWT_SECRET` environment variable; panics with
/// "JWT_SECRET must be set" if the variable cannot be read.
static KEYS: Lazy<Keys> = Lazy::new(|| {
  std::env::var("JWT_SECRET")
      .map(|secret| Keys::new(secret.as_bytes()))
      .expect("JWT_SECRET must be set")
});

/// Entry point: installs the tracing subscriber, builds the router and serves
/// it on `127.0.0.1:3000`.
#[tokio::main]
async fn main() {
  // Use `RUST_LOG` when it is set; otherwise fall back to debug output for
  // this example crate.
  let filter_directives =
      std::env::var("RUST_LOG").unwrap_or_else(|_| "example_jwt=debug".into());
  tracing_subscriber::registry()
      .with(tracing_subscriber::EnvFilter::new(filter_directives))
      .with(tracing_subscriber::fmt::layer())
      .init();

  // Two routes: one issues tokens, the other requires a valid token.
  let router = Router::new()
      .route("/protected", get(protected))
      .route("/authorize", post(authorize));

  let listen_addr = SocketAddr::from(([127, 0, 0, 1], 3000));
  tracing::debug!("listening on {}", listen_addr);

  // Serve until the server future resolves; a server error aborts the process.
  let server = axum::Server::bind(&listen_addr).serve(router.into_make_service());
  server.await.unwrap();
}

/// Handler for `GET /protected`.
///
/// The `Claims` argument is produced by the `FromRequest` implementation, so
/// this body only runs when a token was successfully decoded. Returns a
/// greeting followed by the claims' `Display` output.
async fn protected(claims: Claims) -> Result<String, AuthError> {
  let greeting = format!("Welcome to the protected area :)\nYour data:\n{}", claims);
  Ok(greeting)
}

async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
  // Check if the user sent the credentials
  if payload.client_id.is_empty() || payload.client_secret.is_empty() {
      return Err(AuthError::MissingCredentials);
  }
  // Here you can check the user credentials from a database
  if payload.client_id != "foo" || payload.client_secret != "bar" {
      return Err(AuthError::WrongCredentials);
  }
  let claims = Claims {
      sub: "b@b.com".to_owned(),
      company: "ACME".to_owned(),
      // Mandatory expiry time as UTC timestamp
      exp: 2000000000, // May 2033
  };
  // Create the authorization token
  let token = encode(&Header::default(), &claims, &KEYS.encoding)
      .map_err(|_| AuthError::TokenCreation)?;

  // Send the authorized token
  Ok(Json(AuthBody::new(token)))
}

// Human-readable rendering of the claims, used in the `/protected` response.
impl Display for Claims {
  /// Writes the subject (labelled "Email") and the company on two lines.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      let (email, company) = (&self.sub, &self.company);
      f.write_fmt(format_args!("Email: {}\nCompany: {}", email, company))
  }
}

impl AuthBody {
  /// Wraps `access_token` in a response body whose token type is "Bearer".
  fn new(access_token: String) -> Self {
      let token_type = String::from("Bearer");
      Self {
          access_token,
          token_type,
      }
  }
}

// Extractor that turns the request's bearer token into `Claims`.
// Any failure (header extraction or token decoding) is reported as
// `AuthError::InvalidToken`.
#[async_trait]
impl<B> FromRequest<B> for Claims
where
  B: Send,
{
  type Rejection = AuthError;

  async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
      // Pull the `Authorization: Bearer ...` header out of the request.
      let header = TypedHeader::<Authorization<Bearer>>::from_request(req)
          .await
          .map_err(|_| AuthError::InvalidToken)?;
      let TypedHeader(Authorization(bearer)) = header;

      // Decode the token with the shared decoding key and the library's
      // default validation settings, keeping only the claims.
      let validation = Validation::default();
      decode::<Claims>(bearer.token(), &KEYS.decoding, &validation)
          .map(|token_data| token_data.claims)
          .map_err(|_| AuthError::InvalidToken)
  }
}

impl IntoResponse for AuthError {
  fn into_response(self) -> Response {
      let (status, error_message) = match self {
          AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
          AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
          AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
          AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
      };
      let body = Json(json!({
          "error": error_message,
      }));
      (status, body).into_response()
  }
}

/// Pair of JWT keys; `Keys::new` builds both from the same secret bytes.
struct Keys {
  /// Key passed to `encode` when issuing tokens.
  encoding: EncodingKey,
  /// Key passed to `decode` when checking incoming tokens.
  decoding: DecodingKey,
}

impl Keys {
  /// Builds the encoding and decoding keys from the same `secret` bytes.
  fn new(secret: &[u8]) -> Self {
      let encoding = EncodingKey::from_secret(secret);
      let decoding = DecodingKey::from_secret(secret);
      Self { encoding, decoding }
  }
}

/// Payload carried inside the JWT; serialized when a token is issued and
/// deserialized when a token is decoded.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
  /// Token subject; rendered as "Email" by the `Display` implementation.
  sub: String,
  /// Company name associated with the subject.
  company: String,
  /// Expiry time as a UTC timestamp (the issuer sets 2000000000, i.e. May 2033).
  exp: usize,
}

/// JSON body returned by the `authorize` handler on success.
#[derive(Debug, Serialize)]
struct AuthBody {
  /// The encoded token the client should present on later requests.
  access_token: String,
  /// Token scheme; `AuthBody::new` always sets this to "Bearer".
  token_type: String,
}

/// JSON body expected by the `authorize` handler.
#[derive(Debug, Deserialize)]
struct AuthPayload {
  /// Client identifier; must be non-empty.
  client_id: String,
  /// Client secret; must be non-empty.
  client_secret: String,
}

/// Errors the handlers and the `Claims` extractor can return; each one is
/// turned into an HTTP response by the `IntoResponse` implementation.
#[derive(Debug)]
enum AuthError {
  /// Credentials were supplied but did not match (401 Unauthorized).
  WrongCredentials,
  /// `client_id` or `client_secret` was empty (400 Bad Request).
  MissingCredentials,
  /// Encoding the token failed (500 Internal Server Error).
  TokenCreation,
  /// The bearer token was missing or could not be decoded (400 Bad Request).
  InvalidToken,
}
